#! /usr/bin/env python3
"""
--------------------------------------------------------------------------------
Copyright 2022 SiFive Inc

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
--------------------------------------------------------------------------------
"""

from __future__ import print_function
import argparse
import os
import glob
import sys
import subprocess
from junitparser import JUnitXml, TestSuite, TestCase, Skipped, Error, Failure

# Put the parent directory of this script at the front of the module search
# path before the imports below (presumably that is where `generator` and
# `inst` live -- confirm against the repository layout).
sys.path = [os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))] + sys.path

import generator
import inst

class bcolors:
    """ANSI escape sequences used to decorate terminal output."""

    # Terminates any active attribute; NONE is the "no decoration" prefix.
    ENDC = '\033[0m'
    NONE = ''

    # Text attributes.
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'

    # Bright foreground colors.
    FAIL = '\033[91m'
    OKGREEN = '\033[92m'
    WARNING = '\033[93m'
    OKBLUE = '\033[94m'
    HEADER = '\033[95m'

def bstr(color, msg):
  """Return msg wrapped in the given color escape when stdout is a terminal.

  When stdout is not a terminal, msg is returned unchanged (not converted to
  a string).
  """
  if not sys.stdout.isatty():
    return msg
  return "%s%s%s" % (color, msg, bcolors.ENDC)

def bprint(color, msg):
  """Print msg, wrapped in the given color escape when stdout is a terminal."""
  text = msg
  if sys.stdout.isatty():
    # Only decorate interactive output; redirected output stays plain.
    text = "%s%s%s" % (color, msg, bcolors.ENDC)
  print(text)

class TestResult:
  """Collects test names bucketed by outcome.

  Every bucket starts as its own empty list; `intrinsic_report` records the
  flag passed to the constructor.
  """

  def __init__(self, intrinsic_report=False):
    # Each attribute gets a distinct list object.
    for bucket in ("pass_list",
                   "failed_list",
                   "compile_failed_list",
                   "gen_failed_list",
                   "run_failed_list",
                   "unknown_failed_list"):
      setattr(self, bucket, [])
    self.intrinsic_report = intrinsic_report

def dump_xml_report(fname: str, testsuite: TestSuite) -> None:
  """Write `testsuite` as a JUnit XML document to the path `fname`.

  A one-line notice naming the output file is printed first.
  """
  print(f"Generate XML report: {fname}")
  junit_doc = JUnitXml()
  junit_doc.add_testsuite(testsuite)
  junit_doc.write(fname)

def api_testing_report(stats, opts):
  """Classify every `*.log` file in the current directory as pass or fail.

  The test name is the part of the file name before the first '.'; its
  group/sub-group comes from `g.query_group_desc`.  An empty log is a pass.
  Otherwise only the last line of the log is inspected:
    * it contains 'error'                              -> failed (Failure)
    * it contains 'warning' and opts.warning_as_error  -> failed (Error)
    * anything else                                    -> pass

  Outcomes are appended to `stats[grp][subgrp]` and to the returned
  TestResult.  When opts.gen_junit_report is set, a JUnit XML report named
  opts.testsuite is written to opts.junit_report_path.
  """
  result = TestResult()
  test_cases = []

  for log_path in glob.glob("*.log"):
    testname = log_path.split('.')[0]

    grp, subgrp = g.query_group_desc(testname)
    test_case = TestCase(name = testname, classname = f"{grp}.{subgrp}", time = 0.0)

    # JUnit entry describing why the test failed; stays None for a pass.
    problem = None
    if os.path.getsize(log_path) != 0:
      with open(log_path, 'r') as fp:
        last_line = fp.readlines()[-1]
      if 'error' in last_line:
        problem = Failure(last_line, "stdout/stderr")
      elif 'warning' in last_line and opts.warning_as_error:
        problem = Error("Treat warning as error", "warning")

    if problem is None:
      stats[grp][subgrp].pass_list.append(testname)
      result.pass_list.append(testname)
    else:
      stats[grp][subgrp].failed_list.append(testname)
      result.failed_list.append(testname)
      test_case.result = [problem]
    test_cases.append(test_case)

  testsuite = TestSuite(name = opts.testsuite)
  testsuite.add_testcases(test_cases)
  if opts.gen_junit_report:
    dump_xml_report(opts.junit_report_path, testsuite)
  return result

def intrinsic_report(stats, opts):
  """Classify every `*.log` file in the current directory by its last line.

  The test name is the part of the file name before the first '.'.  Its
  group/sub-group is found by querying `g.query_group_desc` with
  progressively longer '_'-joined prefixes of the test name; the last
  successful lookup wins.

  The stripped last line of the log decides the outcome:
    * 'PASS'                       -> pass
    * 'COMPILE FAIL' / 'GEN FAIL' / 'RUN FAIL' -> failed, with that category
    * any other line ending in 'FAIL' -> failed, unknown category
    * anything else (including an empty log) -> failed, uncategorized

  Outcomes are appended to `stats[grp][subgrp]` and to the returned
  TestResult.  When opts.gen_junit_report is set, a JUnit XML report named
  opts.testsuite is written to opts.junit_report_path.

  The log is read once in Python instead of shelling out to `tail`/`grep`:
  `grep` exits with status 1 when it selects no line, which made
  `subprocess.check_output` raise for logs holding only the verdict line.
  """
  import re  # local import: `re` is not imported at module level

  # Same selection as `grep -w FAIL`: FAIL as a whole word.
  fail_word = re.compile(r'\bFAIL\b')

  logs = glob.glob("*.log")
  result = TestResult(intrinsic_report=True)
  test_cases = []
  for log in logs:
    testname = log.split('.')[0]

    x = testname.split('_')
    grp = None
    subgrp = None
    for i in range(1, len(x)+1):
      try:
        grp, subgrp = g.query_group_desc("_".join(x[:i]))
      except Exception:
        # This prefix is not a known test name; try a longer one.
        continue
    if grp is None or subgrp is None:
      print ("Unknown group for %s ???" % log)
      # No per-group bucket exists for this test: record it only in the
      # overall result by routing the per-group updates to a scratch object.
      stat = TestResult()
    else:
      stat = stats[grp][subgrp]

    classname = f"{grp}.{subgrp}"
    # default set runtime = 0.0s
    test_case = TestCase(name = testname, classname = classname, time = 0.0)

    with open(log, 'r', encoding='utf-8', errors='replace') as fp:
      log_lines = [text.rstrip('\n') for text in fp]
    # An empty log has no last line and falls through to the final branch.
    line = log_lines[-1].strip() if log_lines else ''

    if line.endswith('FAIL'):
      result.failed_list.append(testname)
      stat.failed_list.append(testname)
      # Everything except the verdict line(s) is reported as the error text.
      error_msg = "\n".join(
        text for text in log_lines if not fail_word.search(text)).strip()
      if line == 'COMPILE FAIL':
        stat.compile_failed_list.append(testname)
        result.compile_failed_list.append(testname)
        test_case.result = [Error(error_msg, "Compile error")]
      elif line == 'GEN FAIL':
        result.gen_failed_list.append(testname)
        stat.gen_failed_list.append(testname)
        test_case.result = [Error(error_msg, "Gen error")]
      elif line == 'RUN FAIL':
        result.run_failed_list.append(testname)
        stat.run_failed_list.append(testname)
        test_case.result = [Error(error_msg, "Runtime error")]
      else:
        print ("Unknown fail type for %s ???" % log)
        stat.unknown_failed_list.append(testname)
        result.unknown_failed_list.append(testname)
        test_case.result = [Error(error_msg,"Unknown error")]
    elif line == 'PASS':
      stat.pass_list.append(testname)
      result.pass_list.append(testname)
    else:
      result.failed_list.append(testname)
      print ("Unknown fail for %s ???" % log)
      stat.unknown_failed_list.append(testname)
      result.unknown_failed_list.append(testname)
      # Only the lines starting with '+' are reported as the error text.
      error_msg = "\n".join(
        text for text in log_lines if text.startswith('+')).strip()
      test_case.result = [Error(error_msg, "Uncatogorized error")]
    test_cases.append(test_case)

  # Init testsuite object and dump to JUnit xml file
  testsuite = TestSuite(name = opts.testsuite)
  testsuite.add_testcases(test_cases)
  if opts.gen_junit_report:
    dump_xml_report(opts.junit_report_path, testsuite)

  return result

def intrinsic_ci_report(stats, opts):
  """Build a result from the test names given on the command line.

  Every name in opts.all is looked up with `g.query_group_desc` and recorded
  in `stats[grp][subgrp]` as failed when it also appears in opts.failure,
  otherwise as passed.  The returned TestResult lists the passing names and
  uses opts.failure as its failed list.

  Returns the TestResult itself: `report()` reads both `failed_list` and
  `pass_list` from it, so returning only the failed list broke the caller.
  """
  result = TestResult()

  # Set lookup keeps the membership test cheap for long test lists.
  failed_names = set(opts.failure)
  for name in opts.all:
    grp, subgrp = g.query_group_desc(name)
    if name in failed_names:
      stats[grp][subgrp].failed_list.append(name)
    else:
      stats[grp][subgrp].pass_list.append(name)
      result.pass_list.append(name)
  result.failed_list = opts.failure
  return result

def report(result, summary):
  """Print pass/total counts per group and overall; return 1 on any failure.

  Per-group numbers come from the module-level `stats` and `g.groups`; the
  overall total comes from `result`.  With `summary` true only one line per
  group plus the total is printed.  Otherwise each sub-group gets its own
  line and, when there are failures, the sorted list of failed tests is
  printed (with the failure category when `result.intrinsic_report` is set).

  Returns 1 if `result.failed_list` is non-empty, else 0.
  """
  separator = "-"*65
  print (separator)
  print ("API Test Summary:" if summary else "API Test Report:")
  print (separator)

  for grp, subgrps in g.groups.items():
    if not summary:
      print ("%s" % grp)
    grp_npass = 0
    grp_nfail = 0
    for subgrp in subgrps:
      stat = stats[grp][subgrp]
      npass = len(stat.pass_list)
      nfail = len(stat.failed_list)
      grp_npass += npass
      grp_nfail += nfail
      if not summary:
        # Detailed mode: one line per sub-group, highlighted on failure.
        line_color = bcolors.FAIL if nfail != 0 else bcolors.NONE
        psubgrp = subgrp.replace("Functions", "").replace("Vector ", "")
        bprint (line_color, " - %-65s: %4s/%4s" % (psubgrp, npass, npass + nfail))

    grp_ncase = grp_nfail + grp_npass
    grp_color = bcolors.FAIL if grp_nfail != 0 else bcolors.NONE
    if summary:
      bprint (grp_color, "%-50s: %4s/%4s" % (grp, grp_npass, grp_ncase))
    else:
      bprint (grp_color, "%s ---------" % (" " * 69))
      bprint (grp_color, "%s %4s/%4s" % (" " * 69, grp_npass, grp_ncase))

  total_fail = len(result.failed_list)
  total_pass = len(result.pass_list)
  rv = 1 if total_fail != 0 else 0
  total_color = bcolors.FAIL if total_fail != 0 else bcolors.NONE

  print (separator)
  bprint (total_color, "%-50s: %4s/%4s" % ("Total", total_pass, total_fail + total_pass))
  print (separator)

  # The failed list is only part of the detailed report.
  if summary or not total_fail:
    return rv

  print ("Failed list:")
  if not result.intrinsic_report:
    print (" - ", end="")
    print ("\n - ".join(sorted(result.failed_list)))
    return rv

  for failed in sorted(result.failed_list):
    # First matching category wins; the check order is run, gen, compile.
    if failed in result.run_failed_list:
      reason, color = "(RUN FAIL)", bcolors.WARNING
    elif failed in result.gen_failed_list:
      reason, color = "(GEN FAIL)", bcolors.FAIL
    elif failed in result.compile_failed_list:
      reason, color = "(COMPILE FAIL)", bcolors.OKGREEN
    else:
      reason, color = "(UNKNOWN FAIL)", bcolors.NONE
    print (' - %s %s' % (failed, bstr(color, reason)))
  return rv

def parse_args(args):
  """Parse the command line and return the resulting argparse namespace.

  Args:
    args: Full argument vector including the program name at index 0 (the
      caller passes `sys.argv`).  `None` falls back to `sys.argv`.

  Returns:
    The namespace with `intrinsic_report`, `gen_junit_report`, `testsuite`,
    `junit_report_path`, `all`, `failure` and `warning_as_error`.

  Unrecognized arguments are reported with a warning and otherwise ignored.
  Exits through `parser.error` when --gen-junit-report is given without a
  non-empty --testsuite.
  """
  parser = argparse.ArgumentParser()
  parser.add_argument("--intrinsic-report", action="store_true", default=False)
  parser.add_argument("--gen-junit-report", action="store_true", default=False)
  parser.add_argument("--testsuite", type=str, default="")
  parser.add_argument("--junit-report-path", type=str, default="test_report.xml")
  parser.add_argument("--all", nargs = "*")
  parser.add_argument("--failure", nargs = "*", default = [])
  parser.add_argument("--warning-as-error", action="store_true", default=False)

  # Parse the vector that was passed in (minus the program name) rather than
  # silently re-reading sys.argv.
  argv = sys.argv if args is None else args
  opts, unknownargs = parser.parse_known_args(list(argv)[1:])

  # Handle unknown args
  if unknownargs:
    print(f"[WARNING] Unknown arguments: {unknownargs}")
  # Make sure provide testsuite name when generate report
  if opts.gen_junit_report and opts.testsuite == "":
    parser.error("[ERROR] Please provide testsuite name")

  return opts

if __name__ == "__main__":
  # `g`, `stats` and `opts` are module-level names; the report functions
  # above read `g` and `stats` as globals.
  g = generator.Grouper()
  inst.gen(g)

  # One empty TestResult per (group, sub-group) pair known to the grouper.
  stats = {
    grp: {subgrp: TestResult() for subgrp in subgrps}
    for grp, subgrps in g.groups.items()
  }

  opts = parse_args(sys.argv)
  # --all takes precedence over --intrinsic-report; the API test report is
  # the default.
  if opts.all:
    build_result = intrinsic_ci_report
  elif opts.intrinsic_report:
    build_result = intrinsic_report
  else:
    build_result = api_testing_report
  result = build_result(stats, opts)

  # Detailed report first, then the summary; the exit status is the sum of
  # the two return codes.
  rv = report(result, summary=False) + report(result, summary=True)

  sys.exit(rv)
